package com.coolingme.server;

import org.springframework.stereotype.Component;

import javax.websocket.*;
import javax.websocket.server.PathParam;
import javax.websocket.server.ServerEndpoint;
import java.io.IOException;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;

/**
 * Core WebSocket endpoint.
 *
 * <p>Clients connect to {@code /webSocket/{sid}}; the {@code sid} path segment is used as the
 * user name under which the connection's {@link Session} is registered. Registered sessions and
 * the online counter are static so every endpoint instance shares them.
 *
 * @author wangyue
 * @date 2020/10/16 10:35
 */
@ServerEndpoint("/webSocket/{sid}")
@Component
public class WebSocketServer {

    /**
     * Number of currently open connections. Shared across all connections, so it must be
     * thread-safe; an {@link AtomicInteger} provides that.
     */
    private static final AtomicInteger onlineNum = new AtomicInteger();

    /**
     * Open sessions keyed by user name (the {@code sid} path parameter). Shared across all
     * connections, hence a thread-safe {@link ConcurrentHashMap}. Note that it rejects
     * {@code null} keys.
     */
    private static final ConcurrentHashMap<String, Session> sessionPools = new ConcurrentHashMap<>();

    /**
     * Sends a text message to a single session using its blocking (basic) remote endpoint.
     *
     * <p>{@code null} and already-closed sessions are silently skipped. Sends are serialized per
     * session via {@code synchronized (session)} so concurrent callers (e.g. a broadcast and a
     * direct reply) do not write to the same basic remote at the same time.
     *
     * @param session target session; may be {@code null}
     * @param message text to send
     * @throws IOException if the underlying send fails
     */
    public void sendMessage(Session session, String message) throws IOException {
        if (session == null || !session.isOpen()) {
            return;
        }
        synchronized (session) {
            System.out.println("Server：" + message);
            session.getBasicRemote().sendText(message);
        }
    }

    /**
     * Sends a message to the session registered under the given user name, if any.
     *
     * <p>Best effort: failures are printed, not propagated.
     *
     * @param userName user name (the {@code sid} the client connected with); may be {@code null}
     * @param message  text to send
     */
    public void sendInfo(String userName, String message) {
        // ConcurrentHashMap.get(null) throws NPE, so treat a null user as "not connected".
        Session session = (userName == null) ? null : sessionPools.get(userName);
        try {
            sendMessage(session, message);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Broadcasts a message to every registered session.
     *
     * <p>Best effort: a failure for one recipient is printed and the loop moves on to the next.
     *
     * @param message text to send
     */
    public void sendGroupInfo(String message) {
        for (Session s : sessionPools.values()) {
            try {
                sendMessage(s, message);
            } catch (Exception e) {
                // One failing recipient must not abort the broadcast.
                e.printStackTrace();
            }
        }
    }

    /**
     * Called when a connection is established: registers the session, bumps the online counter
     * and sends a welcome message back to the client.
     *
     * <p>If a session is already registered under the same user name it is replaced in the map
     * (the old connection itself is not closed here).
     *
     * @param session  the newly opened session
     * @param userName the {@code sid} path parameter
     */
    @OnOpen
    public void onOpen(Session session, @PathParam(value = "sid") String userName) {
        sessionPools.put(userName, session);
        addOnlineCount();
        System.out.println(userName + "加入webSocket！当前人数为" + onlineNum);
        try {
            sendMessage(session, "欢迎" + userName + "加入连接！");
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    /**
     * Called when a connection is closed: unregisters the session and decrements the online
     * counter.
     *
     * @param session  the session being closed (injected by the WebSocket container)
     * @param userName the {@code sid} path parameter
     */
    @OnClose
    public void onClose(Session session, @PathParam(value = "sid") String userName) {
        // Conditional remove: if the same user reconnected before this (older) connection
        // finished closing, the map already holds the newer session and it must stay registered.
        sessionPools.remove(userName, session);
        subOnlineCount();
        System.out.println(userName + "断开webSocket连接！当前人数为" + onlineNum);
    }

    /**
     * Called when a text message arrives from a client.
     *
     * <p>Messages starting with {@code "hello"} get a direct reply on the same session; any other
     * message triggers a {@code "hi"} broadcast to all registered sessions.
     *
     * @param session the session the message arrived on
     * @param message the received text
     * @throws IOException if the direct reply fails
     */
    @OnMessage
    public void onMessage(Session session, String message) throws IOException {
        System.out.println("Client：" + message + "               |               " + session);

        if (message.startsWith("hello")) {
            sendMessage(session, "hi                    |" + session);
        } else {
            sendGroupInfo("hi");
        }
    }

    /**
     * Called when the container reports an error for a session; prints the error.
     *
     * @param session   the affected session
     * @param throwable the reported error
     */
    @OnError
    public void onError(Session session, Throwable throwable) {
        System.err.println("Server Error");
        throwable.printStackTrace();
    }

    /**
     * Increments the online connection count by one.
     */
    public static void addOnlineCount() {
        onlineNum.incrementAndGet();
    }

    /**
     * Decrements the online connection count by one.
     */
    public static void subOnlineCount() {
        onlineNum.decrementAndGet();
    }

}
